from typing import Any
import openai
import time
import sys
import io
import re
from copy import deepcopy
import json
import numpy as np
from tqdm import trange, tqdm
import logging
import signal

# Configure root logging at DEBUG level; each record shows the timestamp,
# level, message, and the source file name and line number that emitted it.
logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s [%(levelname)s] %(message)s (%(filename)s:%(lineno)d)",
)
# Default time limit, in seconds, used by the `timeout` context manager below.
TIMEOUT_DURATION = 10

def extract_func_names_from_snippet(s):
    """Collect the signature of every top-level function defined in a snippet.

    A line counts as a definition when it starts with ``"def "`` at column 0
    (indented/nested functions and methods are ignored). The signature is the
    text of that line up to, but not including, the colon that ends the
    ``def`` header.

    Colons nested inside brackets (type annotations such as ``x: int``, dict
    or lambda defaults) are skipped, so ``def f(x: int) -> int:`` yields
    ``"def f(x: int) -> int"`` rather than being cut at the first colon.
    If no terminating colon is found on the line (e.g. a signature that
    continues on the next line), the whole line is returned.

    Args:
        s: Source code as a single string.

    Returns:
        A list of signature strings, in order of appearance.
    """
    func_signatures = []
    for line in s.split("\n"):
        if not line.startswith("def "):
            continue
        depth = 0
        end = len(line)
        for idx, ch in enumerate(line):
            if ch in "([{":
                depth += 1
            elif ch in ")]}":
                depth -= 1
            elif ch == ":" and depth <= 0:
                # First colon outside any bracket terminates the header.
                end = idx
                break
        func_signatures.append(line[:end])
    return func_signatures



class timeout:
    """Context manager that raises ``TimeoutError`` if its body runs too long.

    Implemented with ``SIGALRM``, so it only works on platforms that provide
    that signal and only in the main thread. Instances are not nestable:
    leaving an inner ``timeout`` cancels the pending timer of an outer one.

    Args:
        seconds: Time limit in seconds; fractional values are accepted.
            A value of 0 disables the limit.
        error_message: Message carried by the raised ``TimeoutError``.
    """

    def __init__(self, seconds=TIMEOUT_DURATION, error_message='Timeout'):
        self.seconds = seconds
        self.error_message = error_message
        # Handler that was active before __enter__, restored in __exit__.
        self._previous_handler = None

    def handle_timeout(self, signum, frame):
        """Signal handler: abort the guarded block with ``TimeoutError``."""
        raise TimeoutError(self.error_message)

    def __enter__(self):
        # signal.signal returns the handler it replaces so we can put it back.
        self._previous_handler = signal.signal(signal.SIGALRM, self.handle_timeout)
        # setitimer delivers SIGALRM like alarm() but also accepts floats.
        signal.setitimer(signal.ITIMER_REAL, self.seconds)

    def __exit__(self, exc_type, exc_value, traceback):
        # Cancel the pending timer first so it cannot fire during cleanup.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # The previous handler is None when it was not installed from Python;
        # in that case there is nothing we can restore.
        if self._previous_handler is not None:
            signal.signal(signal.SIGALRM, self._previous_handler)
            self._previous_handler = None

# Initial user prompt sent to the model. The `{problem}` placeholder is filled
# in with str.format (see Policy.get_initial_messages).
TEMPLATE = """In this task, we are going to provide you a coding challenge that is described after the word 'PROBLEM:', and the correct solution should be implemented as a function, starting with 'def ...'. Besides, after your code, you can also write two more statements to debug your code; we will execute the code and send the printed information back to you, and you can use the printed information to try again.

For example: 
PROBLEM: Write a function that returns the square of a number.
Your response can be:
def square(x):
    return x ** 2
print(square(6))

The above solution is correct and we will return the execution result '36' (the execution result of square(6)) to you. 

Now let's get started. Only write a function and at most two more statements. Only write the code and not anything else, since we will need to directly execute your code.
PROBLEM: {problem}
"""

# Follow-up user prompt: reports the printed output of the previous attempt
# and the fraction of tests passed, then restates the problem.
# Placeholders: {output_value}, {fraction}, {problem}.
# NOTE(review): the `:02` format spec only requests a zero-padded minimum
# width of 2; it does not limit the number of decimals, so a value such as
# 1/3 is rendered in full. If two decimal places were intended, `:.2f` is
# presumably what was meant -- confirm before changing, since this text is
# sent to the model.
NEW_MESSAGE = """Here's the output from your previous code: {output_value}. You passed {fraction:02} fraction of tests. Again, here's the PROBLEM:{problem}
Only write a function and at most two more statements for debugging. Only write the code and not anything else, since we will need to directly execute your code."""
# Sentinel action: Env.step runs the problem's gold solution when it gets this.
GOLD_SPECIAL_TOKEN = "GOLD_CODE"

# Not referenced in the visible part of this file; presumably the default
# passed to Env as `num_iterations` -- verify against callers.
NUM_ITERATIONS = 5
# Number of attempts robust_query makes before giving up.
NUM_RETRY = 3


def get_code_execution(content):
    """Execute a code snippet and capture what it prints.

    The snippet runs via ``exec`` in a fresh globals dict under the module's
    ``timeout`` guard. ``sys.stdout`` is redirected to an in-memory buffer for
    the duration of the call and is always restored to whatever it was
    before, including when the snippet raises.

    SECURITY: ``exec`` runs the snippet with the full privileges of this
    process. The snippets are model-generated, so only run this inside an
    isolated environment.

    Args:
        content: Python source code to execute.

    Returns:
        A dict with:
          * ``"output_value"``: everything the snippet printed, or
            ``"Error: <message>"`` if it raised (including on timeout);
          * ``"global_dict"``: the globals namespace after execution
            (possibly partially populated on error);
          * ``"function_name"``: signatures found by
            ``extract_func_names_from_snippet``.

    A ``KeyboardInterrupt`` terminates the process with exit status 0.
    """
    global_dict = {}
    function_names = extract_func_names_from_snippet(content)
    buffer = io.StringIO()

    # Remember the current stream (it may already be a capture object, e.g.
    # under a test runner or notebook) instead of assuming sys.__stdout__.
    previous_stdout = sys.stdout
    sys.stdout = buffer
    try:
        with timeout():
            exec(content, global_dict)
        output_value = buffer.getvalue()
    except KeyboardInterrupt:
        sys.exit(0)
    except Exception as e:
        logging.error(e)
        output_value = "Error: " + str(e)
    finally:
        # Without this, a failing snippet would leave stdout redirected.
        sys.stdout = previous_stdout

    return {
        "output_value": output_value,
        "global_dict": global_dict,
        "function_name": function_names
    }
    


def robust_query(messages, model, temperature):
    """Query the chat-completion API, retrying on failure.

    Makes up to ``NUM_RETRY`` attempts, waiting 10 seconds between failed
    attempts. Every failure is logged.

    Args:
        messages: Chat history as a list of ``{"role", "content"}`` dicts.
        model: Model name passed to the API.
        temperature: Sampling temperature passed to the API.

    Returns:
        The content of the first choice's message, or ``None`` if every
        attempt failed.

    A ``KeyboardInterrupt`` terminates the process with exit status 0.
    """
    for num_retry in range(NUM_RETRY):
        try:
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                temperature=temperature,
            )
            return response.choices[0].message.content
        except KeyboardInterrupt:
            sys.exit(0)
        except Exception as e:
            logging.error(e)
            logging.error(f"Retry {num_retry}...")
            # No point in waiting after the final attempt; we are giving up.
            if num_retry < NUM_RETRY - 1:
                time.sleep(10)
    # All attempts failed; callers treat None as "no action".
    return None


class Env:
    """Environment that scores submitted code against a problem's tests.

    Args:
        problem_dict: Problem description. ``"gold"`` (reference solution
            source) and ``"query"`` (problem statement) are required;
            ``"tests"`` (list of statements to ``exec``, default ``[]``) and
            ``"test_setup_code"`` (default ``""``) are optional.
        num_iterations: Number of steps after which an episode is done.

    NOTE(review): ``test_setup_code`` is stored but not used anywhere in this
    class -- confirm whether tests were meant to run it first.
    """

    def __init__(self, problem_dict, num_iterations):
        self.gold_code = problem_dict["gold"]
        self.tests = problem_dict.get("tests", [])
        self.test_setup_code = problem_dict.get("test_setup_code", "")
        self.step_count = 0
        # The expected signatures are taken from the gold solution.
        self.function_names = extract_func_names_from_snippet(self.gold_code)
        # Only the first test (if any) is revealed in the prompt.
        test = self.tests[0] if len(self.tests) > 0 else None
        self.problem_prompt = add_test_to_prompt(problem_dict["query"], self.function_names, test)
        self.num_iterations = num_iterations

    def reset(self):
        """Start a new episode by resetting the step counter."""
        self.step_count = 0

    def extract_test_name(self):
        """Return the text before the first ``(`` in the first test.

        A leading ``assert `` is dropped first, so ``"assert f(1) == 1"``
        yields ``"f"``. Raises ``IndexError`` when there are no tests.
        """
        test = self.tests[0]
        if test.startswith("assert"):
            test = test.split("assert ")[1]

        test = test.strip()
        test = test.split("(")[0]
        return test

    def step_(self, code):
        """Run ``code``, evaluate it on every test and return an observation.

        Args:
            code: Python source to evaluate, or ``None`` when the policy
                produced nothing.

        Returns:
            A dict with ``"output_value"`` (captured stdout or an error
            string), ``"rewards"`` (1.0/0.0 per test), ``"reward"`` (mean of
            ``rewards``; NaN when there are no tests), ``"done"`` and
            ``"info"`` (always empty).
        """
        if code is None:
            # An empty submission still consumes a step; otherwise an episode
            # in which the policy never produces code could never finish.
            self.step_count += 1
            return {
                "output_value": "Error: empty code",
                "rewards": [0.0] * len(self.tests),
                "done": self.step_count >= self.num_iterations,
                "info": {},
                "reward": 0.0,
            }

        # First run only provides the printed output shown to the policy.
        output_value = get_code_execution(code)["output_value"]

        rewards = []
        for test in self.tests:
            # Re-execute the code per test so each test sees a fresh
            # namespace that earlier tests could not have mutated.
            global_dict = get_code_execution(code)["global_dict"]

            try:
                with timeout():
                    exec(test, global_dict)
            except KeyboardInterrupt:
                sys.exit(0)
            except Exception as e:
                logging.error(e)
                rewards.append(0.0)
            else:
                rewards.append(1.0)

        self.step_count += 1
        mean_reward = np.mean(rewards)
        # Done when the step budget is exhausted or every test passed.
        done = self.step_count >= self.num_iterations or mean_reward.tolist() == 1.0
        return {
            "output_value": output_value,
            "rewards": rewards,
            "done": done,
            "info": {},
            "reward": mean_reward,
        }

    def step(self, code):
        """Like ``step_``, but ``GOLD_SPECIAL_TOKEN`` runs the gold solution."""
        if code == GOLD_SPECIAL_TOKEN:
            return self.step_(self.gold_code)
        return self.step_(code)


def add_test_to_prompt(query, function_names, test=None):
    """Append signature (and optionally test) requirements to a problem query.

    A single space is inserted between the query and the appended sentence
    when the query does not already end in whitespace, so the two sentences
    do not run together (e.g. ``"...number.Your response..."``).

    Args:
        query: The problem statement.
        function_names: Required function signatures; joined with ``","``.
        test: Optional example test the response must pass.

    Returns:
        The full problem prompt string.
    """
    needs_gap = bool(query) and not query[-1].isspace()
    parts = [
        query,
        " " if needs_gap else "",
        f"Your response should have the following function signature(s): {','.join(function_names)}. ",
    ]
    if test is not None:
        parts.append(f"Additionally, your response should pass the following test: {test}. ")
    return "".join(parts)


# Minimal example problem with only the "query" and "tests" keys.
# NOTE(review): there is no "gold" entry, which Env.__init__ reads with
# problem_dict["gold"], so this dict cannot be passed to Env as-is.
toy_problem_dict = {
    "query": "Write a function that take the cube of a number",
    "tests": ["assert f(1) == 1", "assert f(2) == 8", "assert f(3) == 27"]
    }


class Policy:
    """Base class for policies that produce code for a problem.

    ``reset`` must be called before the other methods: it stores the problem
    prompt and initializes ``self.messages``, the chat history.
    """

    # Fence explicitly tagged ``python``; tried first so that responses which
    # contain such a block keep being parsed exactly as before.
    _PYTHON_FENCE = re.compile(r"```python\n(.*?)\n```", re.DOTALL)
    # Fallback: untagged fence or a ``py``/``python``/``python3`` tag in any
    # letter case, optionally followed by trailing spaces.
    _GENERIC_FENCE = re.compile(
        r"```(?:py|python3?)?[ \t]*\r?\n(.*?)\n```",
        re.DOTALL | re.IGNORECASE,
    )

    def __init__(self) -> None:
        pass

    def reset(self, problem_prompt):
        """Start a new conversation for ``problem_prompt``."""
        self.problem_prompt = problem_prompt
        self.messages = self.get_initial_messages()

    def get_initial_messages(self):
        """Return the system message plus the first user message."""
        return [
            {"role": "system", "content": "Be a coding assistant and implement the function required by the user and some debugging statements."}, 
            {"role": "user", "content": TEMPLATE.format(problem=self.problem_prompt)},
        ]

    def adapt_to_observation(self, observation):
        """Append a user message describing the result of the last attempt.

        Args:
            observation: Dict with ``"output_value"`` and ``"rewards"`` keys,
                as returned by ``Env.step``.
        """
        new_message_content = NEW_MESSAGE.format(
            output_value=observation["output_value"],
            fraction=np.mean(observation["rewards"]).tolist(),
            problem=self.problem_prompt,
        )
        self.messages.append({"role": "user", "content": new_message_content})

    def action(self):
        """Return the next action; the base implementation returns ``None``."""
        pass

    @classmethod
    def extract_first_python_code(cls, action):
        """Extract the code from a model response.

        Args:
            action: The raw model response.

        Returns:
            ``None`` if ``action`` is not a string; the body of the first
            ```` ```python ```` fenced block if there is one; otherwise the
            body of the first untagged / ``py``-tagged fenced block; otherwise
            ``action`` unchanged (the response is assumed to be bare code).
        """
        if not isinstance(action, str):
            return None

        for pattern in (cls._PYTHON_FENCE, cls._GENERIC_FENCE):
            match = pattern.search(action)
            if match:
                return match.group(1)

        # No fenced block at all: treat the whole response as code.
        return action


class GoldPolicy(Policy):
    """Policy whose action is always the gold-solution sentinel."""

    def action(self):
        # Env.step replaces this token with the problem's gold code.
        return GOLD_SPECIAL_TOKEN
    

class ChatGPTPolicy(Policy):
    """Policy that obtains code by querying a chat-completion model.

    Args:
        model: Model name passed to ``robust_query``.
        temperature: Sampling temperature passed to ``robust_query``.
    """

    def __init__(self, model, temperature) -> None:
        super().__init__()
        self.model = model
        self.temperature = temperature

    def action(self):
        """Query the model and return the code extracted from its reply.

        The raw reply is recorded in ``self.messages`` as an assistant turn.
        When the query fails (``robust_query`` returns ``None``) nothing is
        recorded, because a message with ``None`` content would corrupt the
        history sent on later queries; ``None`` is returned in that case.
        """
        raw_action = robust_query(self.messages, self.model, self.temperature)
        if raw_action is not None:
            self.messages.append({"role": "assistant", "content": raw_action})
        return Policy.extract_first_python_code(raw_action)

    def impute_messages(self, trajectory):
        """Rebuild a chat history from a recorded trajectory.

        Args:
            trajectory: Iterable of dicts with optional ``"action"`` and
                ``"observation"`` entries. For each step the action (if any)
                becomes an assistant message, followed by a user message
                built from the observation (if any).

        Returns:
            The reconstructed message list. ``self.messages`` is not
            modified; ``self.problem_prompt`` must already be set.
        """
        messages = self.get_initial_messages()
        for step in trajectory:
            observation = step.get("observation")
            action = step.get("action")
            if action is not None:
                messages.append({"role": "assistant", "content": action})
            if observation is not None:
                feedback = NEW_MESSAGE.format(
                    output_value=observation["output_value"],
                    fraction=np.mean(observation["rewards"]).tolist(),
                    problem=self.problem_prompt,
                )
                messages.append({"role": "user", "content": feedback})
        return messages